{
 "cells": [
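  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train a Model for Detecting Cars\n",
    "\n",
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/opengeos/geoai/blob/main/docs/examples/train_car_detection.ipynb)\n",
    "\n",
    "This notebook shows how to use the `geoai` package to export training tiles from a labeled raster, train a Mask R-CNN model to detect cars, run inference on a test image, and vectorize the predicted masks.\n",
    "\n",
    "## Install package\n",
    "\n",
    "To use the `geoai-py` package, ensure it is installed in your environment. Uncomment and run the command below if it is not already installed."
   ]
  },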
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %pip install geoai-py"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Import libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import geoai"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Download sample data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training imagery (GeoTIFF) and its vector labels (GeoJSON).\n",
    "train_raster_url = \"https://huggingface.co/datasets/giswqs/geospatial/resolve/main/cars_7cm.tif\"\n",
    "train_vector_url = (\n",
    "    \"https://huggingface.co/datasets/giswqs/geospatial/resolve/main/car_detection.geojson\"\n",
    ")\n",
    "\n",
    "# Separate raster that the inference step is run on.\n",
    "test_raster_url = \"https://huggingface.co/datasets/giswqs/geospatial/resolve/main/cars_test_7cm.tif\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Download each sample file in order. download_file is expected to return the\n",
    "# local path of the saved copy (TODO confirm against the geoai docs).\n",
    "sample_urls = [train_raster_url, train_vector_url, test_raster_url]\n",
    "train_raster_path, train_vector_path, test_raster_path = [\n",
    "    geoai.download_file(url) for url in sample_urls\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Visualize sample data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the training vector labels on an interactive map. The raster is passed\n",
    "# by its remote URL (not the downloaded copy) as the `tiles` layer.\n",
    "geoai.view_vector_interactive(train_vector_path, tiles=train_raster_url)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preview the test raster, referenced by its remote URL.\n",
    "geoai.view_raster(test_raster_url)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create training data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Folder that receives the exported tiles; later cells build paths from it.\n",
    "out_folder = \"output\"\n",
    "\n",
    "# Tiling options, kept together so they are easy to tune in one place.\n",
    "tile_export_options = {\n",
    "    \"tile_size\": 512,\n",
    "    \"stride\": 256,\n",
    "    \"buffer_radius\": 0,\n",
    "}\n",
    "\n",
    "# Export tiles from the training raster using the vector file as class data.\n",
    "tiles = geoai.export_geotiff_tiles(\n",
    "    in_raster=train_raster_path,\n",
    "    out_folder=out_folder,\n",
    "    in_class_data=train_vector_path,\n",
    "    **tile_export_options,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train object detection model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tile images and labels are read from sub-folders of out_folder (presumably\n",
    "# created by the tile export step; verify). Training output goes to models/.\n",
    "images_dir = f\"{out_folder}/images\"\n",
    "labels_dir = f\"{out_folder}/labels\"\n",
    "models_dir = f\"{out_folder}/models\"\n",
    "\n",
    "geoai.train_MaskRCNN_model(\n",
    "    images_dir=images_dir,\n",
    "    labels_dir=labels_dir,\n",
    "    output_dir=models_dir,\n",
    "    num_channels=3,\n",
    "    pretrained=True,\n",
    "    batch_size=4,\n",
    "    num_epochs=100,\n",
    "    learning_rate=0.005,\n",
    "    val_split=0.2,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Checkpoint to load for inference; assumes training saved it as\n",
    "# best_model.pth under {out_folder}/models (verify the file exists).\n",
    "model_path = f\"{out_folder}/models/best_model.pth\"\n",
    "\n",
    "# Raster path passed to the detection step as its output location.\n",
    "masks_path = \"cars_prediction.tif\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the trained model over the test raster; the result is presumably written\n",
    "# to masks_path (confirm against the geoai docs). window_size/overlap use the\n",
    "# same values (512 / 256) as tile_size/stride in the tile export step.\n",
    "geoai.object_detection(\n",
    "    test_raster_path,\n",
    "    masks_path,\n",
    "    model_path,\n",
    "    window_size=512,\n",
    "    overlap=256,\n",
    "    confidence_threshold=0.5,\n",
    "    batch_size=4,\n",
    "    num_channels=3,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Vectorize masks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert the predicted mask raster to vector features. The result is\n",
    "# presumably also written to output_path, which the cells below read.\n",
    "output_path = \"cars_prediction.geojson\"\n",
    "gdf = geoai.orthogonalize(masks_path, output_path, epsilon=2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Visualize results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the predicted features on an interactive map, with the test raster\n",
    "# (referenced by URL) as the `tiles` layer.\n",
    "geoai.view_vector_interactive(output_path, tiles=test_raster_url)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Style applied to the prediction layer on the left side of the split map.\n",
    "prediction_layer_args = {\"style\": {\"color\": \"red\", \"fillOpacity\": 0.2}}\n",
    "\n",
    "# Left: predicted features; right: the test raster, which is also the basemap.\n",
    "geoai.create_split_map(\n",
    "    basemap=test_raster_url,\n",
    "    left_layer=output_path,\n",
    "    left_args=prediction_layer_args,\n",
    "    right_layer=test_raster_url,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![image](https://github.com/user-attachments/assets/31f4eb34-b829-423d-a9a6-86fd79219770)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
